# Script to plot RGB images of the FIELDMAPS targets

# REMINDER: NumPy arrays are indexed [y, x], the reverse of IDL's [x, y] order

# =============================================================================
# Package dependencies
# =============================================================================

from aplpy import FITSFigure, make_rgb_cube, make_rgb_image
from astropy.io import fits
from astropy.wcs import WCS
from lmfit import Model
import matplotlib.patches as patches
import matplotlib.pyplot as plt
import math
import numpy as np
from scipy import stats, optimize


# =============================================================================
# Base RGB map (scale bar, zoom and optional filament label)
# =============================================================================

def RGBmap(RGB_fits, RGB_png, ScaleLength, Region, FigSize, FilLabel=None, FilLabelX=0.5, FilLabelY=0.95, Scalebar='top right', ScaleLabel='1 pc'):
    """Create an APLpy figure showing an RGB image of a target.

    Parameters
    ----------
    RGB_fits : str or HDU
        FITS input for ``aplpy.FITSFigure``; provides the WCS of the image.
    RGB_png : str
        RGB image drawn on that WCS via ``FITSFigure.show_rgb``.
    ScaleLength : float or Quantity
        Scale-bar length forwarded to ``FITSFigure.add_scalebar``; it must
        correspond to ``ScaleLabel`` at the target distance (presumably an
        angular size in degrees -- confirm with callers).
    Region : sequence of 4 floats
        ``(x, y, width, height)`` forwarded to ``FITSFigure.recenter``.
    FigSize : sequence of 2 floats
        Figure width and height (matplotlib ``figsize``).
    FilLabel : str, optional
        Filament label text; nothing is drawn when ``None``.
    FilLabelX, FilLabelY : float
        Label position in relative (axes-fraction) coordinates.
    Scalebar : str
        Corner in which the scale bar is drawn.
    ScaleLabel : str
        Text written on the scale bar (defaults to the previous ``'1 pc'``).

    Returns
    -------
    aplpy.FITSFigure
        The new figure, ready for further overlays.
    """
    RGB_plot = FITSFigure(RGB_fits, figsize=(FigSize[0], FigSize[1]))
    RGB_plot.show_rgb(RGB_png)
    RGB_plot.tick_labels.set_xformat('dd.d')
    RGB_plot.tick_labels.set_yformat('d.dd')
    # ScaleLength must be the on-sky size corresponding to ScaleLabel
    RGB_plot.add_scalebar(ScaleLength, ScaleLabel, corner=Scalebar, frame=True)
    # Zoom onto the region of interest
    RGB_plot.recenter(x=Region[0], y=Region[1], width=Region[2], height=Region[3])
    # Optional filament label, placed in axes-fraction coordinates
    if FilLabel is not None:
        RGB_plot.add_label(FilLabelX, FilLabelY, FilLabel, relative=True, size=13, color='white')
    return RGB_plot

# =============================================================================
# Adding HAWC+ intensity contours to RGB map
# =============================================================================
def RGB_Add_IContours(RGB_plot, Iref, Ilevels, HAWCBeam=0.00506, Alpha=1.0):
    """Overlay HAWC+ total-intensity contours and the HAWC+ beam marker.

    Parameters
    ----------
    RGB_plot : aplpy.FITSFigure
        Figure to draw on (modified in place).
    Iref : HDU, array or filename accepted by ``FITSFigure.show_contour``
        Total-intensity map to contour.
    Ilevels : sequence of float
        Contour levels.
    HAWCBeam : float
        Beam size used for both the major and minor axes, passed to
        ``beam.set_major``/``set_minor``.
    Alpha : float
        Contour transparency.

    Returns
    -------
    aplpy.FITSFigure
        The same figure, for chaining.
    """
    fig = RGB_plot
    fig.show_contour(data=Iref, colors='white', levels=Ilevels, alpha=Alpha)

    # White beam marker in the lower-left corner
    beam_style = dict(facecolor='white', edgecolor='white',
                      linewidth=2, pad=1, corner='bottom left')
    fig.add_beam(**beam_style)

    # Circular beam: the same size on both axes
    for set_axis in (fig.beam.set_major, fig.beam.set_minor):
        set_axis(HAWCBeam)
    return fig

# =============================================================================
# Adding Herschel Column Density contours to RGB map
# =============================================================================
def RGB_Add_NContours(RGB_plot, Nref, Nlevels, HerschelBeam=0.01011, Alpha=1.0, HighCont='white', LowCont='orange'):
    """Overlay Herschel column-density contours and the Herschel beam marker.

    The lowest contour level is drawn in ``LowCont``. All higher levels are
    drawn in ``HighCont``, so the outermost contour stands out.

    Parameters
    ----------
    RGB_plot : aplpy.FITSFigure
        Figure to draw on (modified in place).
    Nref : HDU, array or filename accepted by ``FITSFigure.show_contour``
        Column-density map to contour.
    Nlevels : iterable of float
        Contour levels, as a list, tuple or numpy array. They are sorted
        here, so input order does not matter. The caller's object is never
        modified.
    HerschelBeam : float
        Beam size used for both the major and minor axes, passed to
        ``beam.set_major``/``set_minor``.
    Alpha : float
        Contour transparency.
    HighCont, LowCont : str
        Colour of the higher levels (and the beam), and of the lowest level.

    Returns
    -------
    aplpy.FITSFigure
        The same figure, for chaining.

    Raises
    ------
    ValueError
        If ``Nlevels`` is empty.
    """
    # sorted() accepts lists, tuples and ndarrays alike (ndarray has no
    # .pop). It also never mutates the caller's levels, and it guarantees
    # levels[0] is the lowest; matplotlib also requires increasing levels.
    levels = sorted(Nlevels)
    if not levels:
        raise ValueError("Nlevels must contain at least one contour level")
    lowest_level, high_levels = levels[:1], levels[1:]

    # With a single level there is nothing to draw in the high colour
    if high_levels:
        RGB_plot.show_contour(data=Nref, colors=HighCont, levels=high_levels, alpha=Alpha)
    RGB_plot.show_contour(data=Nref, colors=LowCont, levels=lowest_level, alpha=Alpha)

    # Herschel beam marker (circular: same size on both axes)
    RGB_plot.add_beam(facecolor=HighCont, edgecolor=HighCont,
                      linewidth=2, pad=1, corner='bottom left')
    RGB_plot.beam.set_major(HerschelBeam)
    RGB_plot.beam.set_minor(HerschelBeam)
    return RGB_plot

# =============================================================================
# Adding HAWC+ vectors to RGB map
# =============================================================================
def RGB_Add_HAWC(RGB_plot, Pref, StepVec=3, ScaleVec=2.0, Color='red', Linewidth=2.0, IdI=10.0, PdP=3.0, Pmax=30.0, Alpha=1.0):
    """Overlay quality-filtered HAWC+ polarization vectors on an RGB figure.

    A vector is drawn only where all of these hold: ``I/dI >= IdI``,
    ``I >= 0``, ``P/dP >= PdP`` and ``P <= Pmax``. Pixels where any test
    cannot be evaluated (NaN inputs, 0/0 ratios) are rejected. Every
    accepted vector gets the same unit amplitude, so only its orientation
    carries information.

    Parameters
    ----------
    RGB_plot : aplpy.FITSFigure
        Figure to draw on (modified in place).
    Pref : astropy.io.fits.HDUList
        HAWC+ product. The extensions read are:
        0 (total intensity I; its header is reused for the vector maps),
        1 (error on I), 8 (polarization fraction P), 9 (error on P) and
        11 (magnetic-field angle).
    StepVec : int
        Draw one vector every ``StepVec`` pixels.
    ScaleVec : float
        Vector length scale passed to ``show_vectors``.
    Color, Linewidth, Alpha :
        Vector appearance.
    IdI : float
        Minimum total-intensity signal-to-noise ratio.
    PdP : float
        Minimum polarization signal-to-noise ratio.
    Pmax : float
        Maximum accepted P, in the units of extension 8 (presumably percent
        given the default of 30 -- confirm).

    Returns
    -------
    aplpy.FITSFigure
        The same figure, for chaining.
    """
    stokes_i, err_i = Pref[0].data, Pref[1].data
    pol, err_pol = Pref[8].data, Pref[9].data

    # Build a "keep" mask rather than "reject" masks (x < threshold). Every
    # comparison with NaN is False, so reject masks silently kept pixels
    # whose SNR was undefined. errstate silences the x/0 and 0/0
    # RuntimeWarnings those pixels would otherwise raise.
    with np.errstate(divide='ignore', invalid='ignore'):
        keep = ((stokes_i / err_i >= IdI)    # total-intensity SNR cut
                & (stokes_i >= 0)            # non-negative intensity
                & (pol / err_pol >= PdP)     # polarization SNR cut
                & (pol <= Pmax))             # upper cut on P

    # Unit amplitude where accepted; NaN (no vector drawn) elsewhere
    amplitude = np.where(keep, 1.0, np.nan)

    header = Pref[0].header
    pmap = fits.PrimaryHDU(data=amplitude, header=header)
    # Copy so the angle map never aliases the caller's HDUList data
    omap = fits.PrimaryHDU(data=Pref[11].data.copy(), header=header)

    RGB_plot.show_vectors(pmap, omap, scale=ScaleVec, step=StepVec, color=Color,
                          zorder=3, alpha=Alpha, linewidth=Linewidth)
    return RGB_plot

# =============================================================================
# Adding Planck vectors to RGB map
# =============================================================================
def RGB_Add_Planck(RGB_plot, Planck, StepVec=3, ScaleVec=2.0, Color='blue', Linewidth=2.0, Alpha=1.0):
    """Overlay Planck polarization vectors of uniform length on a figure.

    Parameters
    ----------
    RGB_plot : aplpy.FITSFigure
        Figure to draw on (modified in place).
    Planck : HDUList-like
        Extension 4 supplies the grid and header of the amplitude map; a
        copy of it has every pixel set to 1.0. Extension 6 supplies the
        vector angles (presumably the magnetic-field orientation -- confirm).
    StepVec : int
        Draw one vector every ``StepVec`` pixels.
    ScaleVec : float
        Vector length scale passed to ``show_vectors``.
    Color, Linewidth, Alpha :
        Vector appearance.

    Returns
    -------
    aplpy.FITSFigure
        The same figure, for chaining.
    """
    unit_hdu = Planck[4].copy()
    # Equal amplitude everywhere, so only the orientation is shown
    unit_hdu.data[:, :] = 1.0
    angle_hdu = Planck[6].copy()

    RGB_plot.show_vectors(unit_hdu, angle_hdu, scale=ScaleVec, step=StepVec,
                          color=Color, alpha=Alpha, linewidth=Linewidth)
    return RGB_plot